Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable passing down dynamic dimensions from torch to XLA #5790

Merged
merged 10 commits into from
Nov 15, 2023

Conversation

lsy323
Copy link
Collaborator

@lsy323 lsy323 commented Nov 11, 2023

This is based on #5778. This PR focuses on supporting unbounded dynamism in LTC infra. The e2e would work when the unbounded dynamism support on XLA side is merged, and required op lowering logic is updated in torch_xla.

  • Add support in LTC to pass down the unbounded dynamism info down to HLO lowering.
  • Guard op lowering logic change for unbounded dynamism support by EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1
  • Fixing a bug in setup.py

Test:
Tested with @sdasgup3's XLA change to propagate unbounded dynamism branch. To make the patch works with torch_xla HEAD, the branch needs to be rebased to the pinned XLA version. Rebased branch

The unbounded dynamism can be propagated in the following example:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

a = torch.tensor([[1,2],[2,4]], device=device)
torch_xla._XLAC._xla_mark_dynamic(a, 0)
b = torch.tensor([[1,2],[2,4]], device=device)
torch_xla._XLAC._xla_mark_dynamic(b, 0)
c = a * b
hlo_content = torch_xla._XLAC._get_xla_tensors_hlo([c])
print(hlo_content)

which gives

loModule IrToHlo.5, entry_computation_layout={(s64[?,2]{1,0}, s64[?,2]{1,0})->(s64[?,2]{1,0})}

ENTRY %IrToHlo.5 (p0.1: s64[?,2], p1.2: s64[?,2]) -> (s64[?,2]) {
  %p1.2 = s64[?,2]{1,0} parameter(1)
  %p0.1 = s64[?,2]{1,0} parameter(0)
  %multiply.3 = s64[?,2]{1,0} multiply(s64[?,2]{1,0} %p1.2, s64[?,2]{1,0} %p0.1)
  ROOT %tuple.4 = (s64[?,2]{1,0}) tuple(s64[?,2]{1,0} %multiply.3)
}

@lsy323 lsy323 marked this pull request as ready for review November 13, 2023 17:52
Copy link
Collaborator

@qihqi qihqi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change the title of PR to something more descriptive: something like "Enable passing down dynamic dimensions from torch to XLA" etc.

Also all C++ use of macro change that to env var like the python one. (so that no need to recompile to enable the feature)

torch_xla/csrc/elementwise.cpp Outdated Show resolved Hide resolved
@lsy323 lsy323 changed the title Mirror of #5778 Make code compilable Enable passing down dynamic dimensions from torch to XLA Nov 13, 2023
@@ -322,6 +329,116 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,
: xla::Reshape(input, shape.dimensions());
}

#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a fan of macro, if you want to guard these functions you can add a check like XLA_CHECK(getenv(EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM))

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg, will update

get_dim_ops.push_back(xla::GetDimensionSize(aux_op, i));

auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

torch_xla/csrc/ir.h Outdated Show resolved Hide resolved
Comment on lines 180 to 200
xla::internal::XlaBuilderFriend builder_friend;
auto* inst = builder_friend.GetInstruction(result_ops[0]);
auto* mutable_dynamic =
inst->mutable_shape()->mutable_is_dynamic_dimension();
if (mutable_dynamic->empty()) {
for (int i = 0; i < inst->dimensions_size(); i++) {
mutable_dynamic->Add(false);
}
}
auto* mutable_dims = inst->mutable_shape()->mutable_dimensions();
for (const auto dim : casted->dynamic_dims()) {
mutable_dynamic->Set(dim, true);
mutable_dims->Set(dim, kUnboundedSize);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we always need to run these or we should guard this in a conditional?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's guard it by checking if the XlaNode has dynamic dim

@lsy323 lsy323 force-pushed the lsiyuan/sandeep-dynamic-shape branch from 671c981 to 23ee18c Compare November 13, 2023 21:52
@lsy323 lsy323 force-pushed the lsiyuan/sandeep-dynamic-shape branch from 48f931b to af14415 Compare November 13, 2023 21:55
torch_xla/csrc/helpers.cpp Outdated Show resolved Hide resolved
0, input_shape.element_type(), input.builder()));
xla::XlaOp scalar = XlaHelpers::ScalarValue<float>(
0, input_shape.element_type(), input.builder());
if (experimental_unbounded_dynamism) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this path execute in the non export path?
Wdyt we limit this experimental condition to torch.export path?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per the discussion below, seems there is no better solution than using a env variable to enable the unbounded dynamism lowering. But we can put more thought on this one.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG. This issue doesn't block this PR. Let's prepare a proposal in parallel.

[](int64_t size) { return size == kUnboundedSize; });
}

xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please prepare an opset issue to track supporting this mode of dynamism?

Similar work example: #5764

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I think we have a list internally

Copy link
Collaborator

@miladm miladm Nov 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it work to maintain the list on GH?

@@ -31,6 +31,9 @@ struct SummationResult {
xla::XlaOp result;
};

static const bool experimental_unbounded_dynamism =
runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

env variables causes poor user experience. What's the plan to clean up a better solution (potentially via the upstream torch API level)?

cc @JackCaoG @qihqi @lsy323

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, for adding env variables, please keep this file up to date.

https://github.com/pytorch/xla/blob/master/configuration.yaml

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Current thought is to use env variable to limit the code path to "Export Only", as you've pointed out in many places. If there are better mechanisms to accomplish that we can also use something different.

Copy link
Collaborator

@miladm miladm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @lsy323 - I left a few comments.

@lsy323 lsy323 force-pushed the lsiyuan/sandeep-dynamic-shape branch 2 times, most recently from f2e65aa to 5459950 Compare November 14, 2023 22:43
torch_xla/csrc/ir.cpp Outdated Show resolved Hide resolved
torch_xla/csrc/ir.h Outdated Show resolved Hide resolved
Copy link
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mostly lgtm, minor nits

@lsy323 lsy323 force-pushed the lsiyuan/sandeep-dynamic-shape branch 2 times, most recently from e7401b3 to f8927e6 Compare November 15, 2023 04:19
@lsy323 lsy323 force-pushed the lsiyuan/sandeep-dynamic-shape branch from f8927e6 to 07d2b43 Compare November 15, 2023 04:25
Copy link
Collaborator

@miladm miladm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks.

@lsy323 lsy323 merged commit a0a048e into master Nov 15, 2023
18 checks passed
@lsy323 lsy323 deleted the lsiyuan/sandeep-dynamic-shape branch November 16, 2023 02:03
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Nov 16, 2023
* port sandeep unbounded dynamism change
* Enable unbounded dynamism using env var, add more guards for unbounded dynamism code path

---------

Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
zpcore pushed a commit that referenced this pull request Nov 21, 2023
* port sandeep unbounded dynamism change
* Enable unbounded dynamism using env var, add more guards for unbounded dynamism code path

---------

Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
lsy323 added a commit to lsy323/xla that referenced this pull request Nov 28, 2023
* port sandeep unbounded dynamism change
* Enable unbounded dynamism using env var, add more guards for unbounded dynamism code path

---------

Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
* port sandeep unbounded dynamism change
* Enable unbounded dynamism using env var, add more guards for unbounded dynamism code path

---------

Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
@sdasgup3 sdasgup3 added the dynamism Dynamic Shape Features label Jan 5, 2024
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
* port sandeep unbounded dynamism change
* Enable unbounded dynamism using env var, add more guards for unbounded dynamism code path

---------

Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
* port sandeep unbounded dynamism change
* Enable unbounded dynamism using env var, add more guards for unbounded dynamism code path

---------

Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
dynamism Dynamic Shape Features
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants